{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
from __future__ import annotations

import typing as T

"""
Any object with a valid __array_interface__

https://numpy.org/doc/stable/reference/arrays.interface.html
"""
Array = T.Any

"""
Any object with a valid __cuda_array_interface__

https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
"""
CudaArray = T.Any

{% if solver is not none %}
class SolverParams:
    """
    Parameters for the solver; passed to the solver constructor and to its set_params method.

    NOTE(review): the grouping comments below are inferred from the field names only; confirm the
    exact meaning of each field against the solver implementation.
    """

    # Iteration caps, presumably for the outer solver loop and the inner PCG loop.
    solver_iter_max: int
    pcg_iter_max: int

    # "diag" settings: an initial value, up/down scaling factors and an exit value.
    # Presumably these control a diagonal damping term (TODO confirm).
    diag_init: float
    diag_scaling_up: float
    diag_scaling_down: float
    diag_exit_value: float

    # Exit criteria for the outer solver, presumably a minimum relative decrease and a score threshold.
    solver_rel_decrease_min: float
    score_exit_value: float

    # Exit criteria for the PCG loop, presumably relative decrease / error / score thresholds.
    pcg_rel_decrease_min: float
    pcg_rel_error_exit: float
    pcg_rel_score_exit: float

class {{solver.struct_name}}:
    def __init__(self, params: SolverParams,
                 *,
                 {% for thing in solver.size_contributors %}
                 {{num_arg_key(thing)}}: int = 0,
                 {% endfor %}
    ): ...

    def set_params(self, params: SolverParams) -> None:
        """
        Set the solver parameters.
        """

    def solve(self, print_progress: bool = False) -> None:
        """
        Run the solver.
        """

    def finish_indices(self) -> None:
        """
        Finish the indices.

        This function has to be called after all indices are set and before the solve function is called.
        """

    def get_allocation_size(self) -> int:
        """
        Get the number of allocated bytes.
        """

    {% for nodetype in solver.node_types %}
    def set_{{nodetype.__name__}}_nodes_from_stacked_host(self, stacked_data: Array, offset: int = 0) -> None:
        """
        Set the current value for the {{nodetype.__name__}} nodes from the stacked host data.

        The offset can be used to start writing at a specific index.
        """

    def set_{{nodetype.__name__}}_nodes_from_stacked_device(self, stacked_data: CudaArray, offset: int = 0) -> None:
        """
        Set the current value for the {{nodetype.__name__}} nodes from the stacked device data.

        The offset can be used to start writing at a specific index.
        """

    def get_{{nodetype.__name__}}_nodes_to_stacked_host(self, out_stacked_data: Array, offset: int = 0) -> None:
        """
        Read the current value for the {{nodetype.__name__}} nodes into the stacked output host data.

        The offset can be used to start reading from a specific index.
        """

    def get_{{nodetype.__name__}}_nodes_to_stacked_device(self, out_stacked_data: CudaArray, offset: int = 0) -> None:
        """
        Read the current value for the {{nodetype.__name__}} nodes into the stacked output device data.

        The offset can be used to start reading from a specific index.
        """

    def set_{{nodetype.__name__}}_num(self, num: int) -> None:
        """
        Set the current number of active nodes of type {{nodetype.__name__}}.

        The value is set during initialization and this function is only needed if you want to change
        the problem between optimization runs. This is work in progress and can have performance impacts.
        """
    {% endfor %}

    {% for factor in solver.factors %}
    {% for argname, argtype in factor.node_arg_types.items() %}
    def set_{{factor.name}}_{{argname}}_indices_from_host(self, indices: Array) -> None:
        """
        Set the indices for the {{argname}} argument for the {{factor.name}} factor from host.
        """

    def set_{{factor.name}}_{{argname}}_indices_from_device(self, indices: CudaArray) -> None:
        """
        Set the indices for the {{argname}} argument for the {{factor.name}} factor from device.
        """
    {% endfor %}

    {% for argname, argtype in factor.const_arg_types.items() %}
    def set_{{factor.name}}_{{argname}}_data_from_stacked_host(self, stacked_data: Array, offset: int = 0) -> None:
        """
        Set the values for the {{argname}} consts {{factor.name}} factor from stacked host data.

        The offset can be used to start writing from a specific index.
        """

    def set_{{factor.name}}_{{argname}}_data_from_stacked_device(self, stacked_data: CudaArray, offset: int = 0) -> None:
        """
        Set the values for the {{argname}} consts {{factor.name}} factor from stacked device data.

        The offset can be used to start writing from a specific index.
        """
    {% endfor %}

    def set_{{factor.name}}_num(self, num: int) -> None:
        """
        Set the current number of {{factor.name}} factors.

        The value is set during initialization and this function is only needed if you want to change
        the problem between optimization runs. This is work in progress and can have performance impacts.
        """
    {% endfor %}
{% endif %}

{% for typ in caslib.exposed_types %}
def {{typ.__name__}}_stacked_to_caspar(stacked_data: CudaArray, out_cas_data: CudaArray) -> None:
    """
    Convert the stacked {{typ.__name__}} data to the caspar data format.

    Both arguments are device arrays. Nothing is returned, so the converted data is presumably
    written into out_cas_data (confirm against the binding).
    """

def {{typ.__name__}}_caspar_to_stacked(caspar_data: CudaArray, out_stacked_data: CudaArray) -> None:
    """
    Convert the caspar {{typ.__name__}} data to the stacked data format.

    Both arguments are device arrays. Nothing is returned, so the converted data is presumably
    written into out_stacked_data (confirm against the binding).
    """

{% endfor %}

def shared_indices(indices: CudaArray, out_shared: CudaArray) -> None:
    """
    Calculate shared indices from the indices.

    Both arguments are device arrays. Nothing is returned, so the result is presumably written
    into out_shared (confirm against the binding).
    """

{% for kernel in caslib.kernels %}
{# One stub per kernel. Accessor arguments whose type is "int" or "float" keep that annotation;
   every other argument is annotated as CudaArray. problem_size is always the last parameter. #}
def {{kernel.name}}(
    {% for accessor in kernel.accessors %}
    {% for arg_name, arg_type in accessor.py_sig.items() %}
    {{arg_name}}: {{ arg_type if arg_type in ["int", "float"] else "CudaArray"}},
    {% endfor %}
    {% endfor %}
    problem_size: int
) -> None: ...

{% endfor %}
